library(tidyverse)
library(data.table)
library(igraph)
library(tidygraph)
library(Matrix)
library(bigalgebra)
library(ergm)

set.seed(12345)

## Simulation study v2

## 1) Load in a network -- https://github.com/briatte/congress
## 2) Use ergm functions to model the network
##    This includes triangles, homophily, whatever you want
## 3) Draw from the network
## 4) Drop some node ideologies
## 5) Use estimation functions to predict the ideology
## 6) Check the accuracy

#### Now for real ####
# Load the saved network objects, then clear the workspace of everything except
# the 117th House network so stale objects cannot leak into the simulation.
# The estimation helpers are sourced afterwards, so they survive the clean-up.
load("data/net_us.rda")
gdata::keep(net_us_hr117, sure = TRUE)
source("code/estimation_functions.R")


# Vertex plotting colour by party. Any party outside the palette -- including a
# missing party -- falls back to the neutral grey instead of becoming NA.
net_us_hr117 %v% "color" <- local({
  party_palette <- c(
    REP = "firebrick",
    DEM = "royalblue",
    IND = "forestgreen"
  )
  party <- as.character(net_us_hr117 %v% "party")
  vertex_colors <- unname(party_palette[party])
  vertex_colors[is.na(vertex_colors)] <- "grey25"
  vertex_colors
})


## Add in NOMINATE
# First-dimension NOMINATE scores keyed by bioguide ID.
dwnom <- read.csv("data/Hall_members.csv") %>%
  select(bioguide_id, nominate_dim1)

# Crosswalk from vertex order to NOMINATE score. match() returns the first
# matching row of `dwnom`, so a member listed more than once contributes a
# single score, and the result always has exactly one row per vertex (in vertex
# order) even if two vertices share a bioguide ID or a vertex has none.
walk <- data.frame(bg_network = net_us_hr117 %v% "bioguide") %>%
  mutate(
    id = seq_len(n()),
    nominate_dim1 = dwnom$nominate_dim1[match(bg_network, dwnom$bioguide_id)]
  )

# Show the vertices without a score so the imputation below is visible.
print(walk[is.na(walk$nominate_dim1), ])

# Vertices without a score are placed at the centre of the scale.
walk$nominate_dim1[is.na(walk$nominate_dim1)] <- 0

net_us_hr117 %v% "dwnom" <- walk$nominate_dim1


## Optional pre-processing, currently disabled: trim weak edges.
# delete.edges(n, which(net_us_hr117 %e% "gsw" < .45))

## Optional diagnostic plot, currently disabled.
# plot(net_us_hr117,
#      vertex.col = net_us_hr117 %v% 'color',
#      vertex.border = net_us_hr117 %v% 'color',
#      edge.col = "grey50")

# Fit the ERGM: edge count, absolute difference in the `dwnom` vertex
# attribute, party main effects, and same-party matching.
m_hr117 <- ergm(
  net_us_hr117 ~ edges +
    absdiff("dwnom") +
    nodefactor("party") +
    nodematch("party")
)

## NOTE: to fit with parallel threads, comment out the call above and
## uncomment the one below.
# m_hr117 <- ergm(
#   net_us_hr117 ~ edges + absdiff('dwnom') + nodefactor('party') +
#     nodematch('party'),
#   control = control.ergm(parallel = 4, parallel.type = "PSOCK")
# )


#' Run one simulate / mask / re-estimate replication.
#'
#' Draws one network from the global `m_hr117` fit, hides the `dwnom` score of
#' a random subset of vertices, re-estimates scores with
#' `make_transition_matrix()` and `estimate_ideology_without_steady_state()`
#' (sourced from code/estimation_functions.R), and compares the predictions
#' for the hidden vertices with their true values.
#'
#' @param p Probability that a vertex KEEPS its observed `dwnom` score; each
#'   vertex is an independent Bernoulli(p) draw. Note that, despite the name
#'   of the `dropprob` output column, a larger `p` hides FEWER scores.
#' @return A one-row data frame computed over the hidden vertices only, with
#'   columns `corr`, `MAD`, `samesign` and `dropprob` (the value of `p`).
hr_sim <- function(p) {
  stopifnot(is.numeric(p), length(p) == 1L, !is.na(p), p >= 0, p <= 1)

  sim1 <- simulate(m_hr117)
  # plot(sim1,
  #      vertex.col = net_us_hr117 %v% 'color',
  #      vertex.border = net_us_hr117 %v% 'color',
  #      edge.col = "grey50")

  vertex_names <- network::network.vertex.names(sim1)
  truth_df <- data.frame(
    name = vertex_names,
    truth = sim1 %v% "dwnom",
    party = sim1 %v% "party"
  )

  ## Hide some node ideologies
  # TRUE = score stays observed, FALSE = score is hidden and must be predicted.
  observed <- as.logical(rbinom(length(sim1 %v% "party"), 1, prob = p))
  masked_scores <- ifelse(observed, truth_df$truth, NA)
  sim1 %v% "dwnom" <- masked_scores
  sim1 %v% "weight" <- 1 ## do I need this?

  truth_df$observed <- observed

  ## Convert network -> igraph -> tbl_graph, with unit edge weights
  g2 <- intergraph::asIgraph(sim1)
  g2 <- g2 %>% set_edge_attr("weight", value = 1)
  V(g2)$name <- vertex_names
  g3 <- as_tbl_graph(g2)

  ## Do the estimation, treating the observed vertices as exogenous
  known_names <- V(g3)$name[observed]
  M <- make_transition_matrix(g3, known_names)
  test_scores <- estimate_ideology_without_steady_state(
    M,
    exog = known_names,
    score_vector = masked_scores[observed]
  )
  test_scores <- test_scores[, 1]

  result_df <- data.frame(name = names(test_scores), pred = test_scores)
  result_df2 <- merge(truth_df, result_df, by = "name")

  # NOTE(review): unless code/estimation_functions.R defines its own `mad`,
  # this is stats::mad(x, center), i.e. 1.4826 * median(abs(pred - truth)).
  # Confirm that a scaled median absolute error is the intended metric.
  result_df2 %>%
    filter(!observed) %>%
    # group_by(party) %>%
    summarize(
      corr = cor(pred, truth),
      MAD = mad(pred, truth),
      samesign = mean(sign(pred) == sign(truth))
    ) %>%
    mutate(dropprob = p)
}
  
# Re-seed so the simulation draws are reproducible independently of the model
# fit above.
set.seed(12345)
# Probabilities 0.05, 0.10, ..., 0.95 passed to hr_sim(), 20 replications each.
params <- rep(seq(0.05, 0.95, by = 0.05), each = 20)

# One one-row summary per replication; preallocated rather than grown.
out <- vector("list", length(params))
for (i in seq_along(params)) {
  out[[i]] <- hr_sim(p = params[i])
  print(i) # progress indicator
}

out2 <- do.call(rbind, out) %>% filter(dropprob < 1)
save(out2, file = "data/hr_sims.RData")